{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reproducibility Tutorial\n",
    "\n",
    "Reproducibility of ML models and evaluations is a frequent problem across ML systems. It is usually two problems: the first is obtaining a description of the computation that was executed, and the second is having a method for replaying that computation. In Tribuo we built our provenance system to make our models *self-describing*, by which we mean they capture a complete description of the computation that produced them, solving the first problem. In v4.2 we added an automated reproducibility system which consumes the provenance data and retrains the model, solving the second. Alongside the reproducibility system we added a mechanism for diffing provenance objects, allowing easy comparison between the reproduced and original models. The diff is useful because the models are only guaranteed to be identical if the data is the same, and any differences in the data will show up in the data provenance object.\n",
    "\n",
    "## Setup\n",
    "\n",
    "Before running this tutorial, please run the irises classification and ONNX export tutorials to build the two models that we're going to reproduce.\n",
    "\n",
    "We're going to load in the classification jar, the ONNX jar, the JSON jar, and the reproducibility jar. Note that the reproducibility jar is written in Java 16, so this tutorial requires Java 16 or later. Then we'll import the necessary classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%jars ./tribuo-classification-experiments-4.3.0-jar-with-dependencies.jar\n",
    "%jars ./tribuo-onnx-4.3.0-jar-with-dependencies.jar\n",
    "%jars ./tribuo-json-4.3.0-jar-with-dependencies.jar\n",
    "%jars ./tribuo-reproducibility-4.3.0-jar-with-dependencies.jar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.tribuo.*;\n",
    "import org.tribuo.classification.*;\n",
    "import org.tribuo.classification.evaluation.*;\n",
    "import org.tribuo.classification.sgd.fm.*;\n",
    "import org.tribuo.classification.sgd.linear.*;\n",
    "import org.tribuo.datasource.*;\n",
    "import org.tribuo.interop.onnx.*;\n",
    "import org.tribuo.reproducibility.*;\n",
    "import com.oracle.labs.mlrg.olcut.provenance.*;\n",
    "import com.oracle.labs.mlrg.olcut.util.*;\n",
    "import ai.onnxruntime.*;\n",
    "\n",
    "import java.nio.file.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reproducing a Tribuo Model\n",
    "\n",
    "The reproducibility system works on Tribuo `Model` or `ModelProvenance` objects. When given a `ModelProvenance`, the system loads the original training data, processes and transforms it according to the recorded columnar processing and transformations, then rebuilds the original trainer including its RNG state, before passing the data into the train method and returning the reproduced model. When given a `Model`, it performs the same steps and then also compares the feature and output domains of the original and reproduced models, to provide more information about any differences between them. Over time we plan to expand the validation applied to the reproduced model to show if the features have different ranges or histograms.\n",
    "\n",
    "We're going to load in the Irises logistic regression model trained in the first tutorial."
   ]
  },
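  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the RNG state has to be rebuilt along with the trainer, here is a minimal sketch using only the JDK. It is not Tribuo code: the `SeedReplay` class and `shuffledIndices` method are hypothetical names, and a seeded `java.util.SplittableRandom` is assumed as a stand-in for whatever RNG a trainer uses. Two generators built from the same seed drive a shuffle to the same order, which is the property a replayed training run relies on.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "import java.util.SplittableRandom;\n",
    "\n",
    "public class SeedReplay {\n",
    "    // Fisher-Yates shuffle of the indices 0..n-1, driven by the supplied RNG.\n",
    "    static int[] shuffledIndices(int n, SplittableRandom rng) {\n",
    "        int[] idx = new int[n];\n",
    "        for (int i = 0; i < n; i++) {\n",
    "            idx[i] = i;\n",
    "        }\n",
    "        for (int i = n - 1; i > 0; i--) {\n",
    "            int j = rng.nextInt(i + 1);\n",
    "            int tmp = idx[i];\n",
    "            idx[i] = idx[j];\n",
    "            idx[j] = tmp;\n",
    "        }\n",
    "        return idx;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // Same seed twice (illustrative value), plus one different seed for contrast.\n",
    "        int[] original = shuffledIndices(10, new SplittableRandom(12345L));\n",
    "        int[] replayed = shuffledIndices(10, new SplittableRandom(12345L));\n",
    "        int[] otherSeed = shuffledIndices(10, new SplittableRandom(1L));\n",
    "        System.out.println(\"Same seed, same order = \" + Arrays.equals(original, replayed));\n",
    "        System.out.println(\"Different seed, same order = \" + Arrays.equals(original, otherSeed));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "If the generator has already been advanced by an earlier training run, a fresh generator with the same seed no longer matches it, which is why the seed alone is not always enough to replay a computation."
   ]
  },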
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=DataSource(class-name=org.tribuo.data.csv.CSVDataSource,headers=[sepalLength, sepalWidth, petalLength, petalWidth, species],rowProcessor=RowProcessor(class-name=org.tribuo.data.columnar.RowProcessor,metadataExtractors=[],fieldProcessorList=[FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=petalLength,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor), FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=petalWidth,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor), FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=sepalWidth,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor), 
FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=sepalLength,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor)],featureProcessors=[],responseProcessor=ResponseProcessor(class-name=org.tribuo.data.columnar.processors.response.FieldResponseProcessor,uppercase=false,fieldNames=[species],defaultValues=[],displayField=false,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),host-short-name=ResponseProcessor),weightExtractor=null,replaceNewlinesWithSpaces=true,regexMappingProcessors={},host-short-name=RowProcessor),quote=\",outputRequired=true,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),separator=,,dataPath=/local/ExternalRepositories/tribuo/tutorials/bezdekIris.data,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC],file-modified-time=1999-12-14T15:12:39-05:00,datasource-creation-time=2022-10-07T11:20:06.279351-04:00,host-short-name=DataSource),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=true,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.3.0),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),loggingInterval=1000,objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),tribuo-version=4.3.0,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2022-10-07T11:20:06.643297-04:00,instance-values={},tribuo-version=4.3.0,java-version=12,os-name=Linux,os-arch=amd64)\n"
     ]
    }
   ],
   "source": [
    "File irisModelFile = new File(\"iris-lr-model.ser\");\n",
    "String filterPattern = Files.readAllLines(Paths.get(\"../docs/jep-290-filter.txt\")).get(0);\n",
    "ObjectInputFilter filter = ObjectInputFilter.Config.createFilter(filterPattern);\n",
    "LinearSGDModel loadedModel;\n",
    "try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(irisModelFile)))) {\n",
    "    ois.setObjectInputFilter(filter);\n",
    "    loadedModel = (LinearSGDModel) ois.readObject();\n",
    "}\n",
    "\n",
    "System.out.println(loadedModel.toString());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reproducibility system lives in the `ReproUtil` class. It is constructed either from a `Model`, or from a `ModelProvenance` together with a `Class<T extends Output<T>>` giving the output class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "var repro = new ReproUtil<>(loadedModel);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can separately rebuild the dataset and the trainer. Note that if you mutate the objects returned by these methods then you won't get the exact same model back from the reproduction. We're still working on the API for the reproducibility system and expect to make it more robust over time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MutableDataset(\n",
      "\tclass-name = org.tribuo.MutableDataset\n",
      "\tdatasource = TrainTestSplitter(\n",
      "\t\t\tclass-name = org.tribuo.evaluation.TrainTestSplitter\n",
      "\t\t\tsource = CSVDataSource(\n",
      "\t\t\t\t\tclass-name = org.tribuo.data.csv.CSVDataSource\n",
      "\t\t\t\t\theaders = List[\n",
      "\t\t\t\t\t\tsepalLength\n",
      "\t\t\t\t\t\tsepalWidth\n",
      "\t\t\t\t\t\tpetalLength\n",
      "\t\t\t\t\t\tpetalWidth\n",
      "\t\t\t\t\t\tspecies\n",
      "\t\t\t\t\t]\n",
      "\t\t\t\t\trowProcessor = RowProcessor(\n",
      "\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.RowProcessor\n",
      "\t\t\t\t\t\t\tmetadataExtractors = List[]\n",
      "\t\t\t\t\t\t\tfieldProcessorList = List[\n",
      "\t\t\t\t\t\t\t\tDoubleFieldProcessor(\n",
      "\t\t\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t\tfieldName = petalLength\n",
      "\t\t\t\t\t\t\t\t\t\t\tonlyFieldName = true\n",
      "\t\t\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n",
      "\t\t\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\t\tDoubleFieldProcessor(\n",
      "\t\t\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t\tfieldName = petalWidth\n",
      "\t\t\t\t\t\t\t\t\t\t\tonlyFieldName = true\n",
      "\t\t\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n",
      "\t\t\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\t\tDoubleFieldProcessor(\n",
      "\t\t\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t\tfieldName = sepalWidth\n",
      "\t\t\t\t\t\t\t\t\t\t\tonlyFieldName = true\n",
      "\t\t\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n",
      "\t\t\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\t\tDoubleFieldProcessor(\n",
      "\t\t\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t\tfieldName = sepalLength\n",
      "\t\t\t\t\t\t\t\t\t\t\tonlyFieldName = true\n",
      "\t\t\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n",
      "\t\t\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n",
      "\t\t\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\t]\n",
      "\t\t\t\t\t\t\tfeatureProcessors = List[]\n",
      "\t\t\t\t\t\t\tresponseProcessor = FieldResponseProcessor(\n",
      "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.response.FieldResponseProcessor\n",
      "\t\t\t\t\t\t\t\t\tuppercase = false\n",
      "\t\t\t\t\t\t\t\t\tfieldNames = List[\n",
      "\t\t\t\t\t\t\t\t\t\tspecies\n",
      "\t\t\t\t\t\t\t\t\t]\n",
      "\t\t\t\t\t\t\t\t\tdefaultValues = List[\n",
      "\t\t\t\t\t\t\t\t\t\t\n",
      "\t\t\t\t\t\t\t\t\t]\n",
      "\t\t\t\t\t\t\t\t\tdisplayField = false\n",
      "\t\t\t\t\t\t\t\t\toutputFactory = LabelFactory(\n",
      "\t\t\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n",
      "\t\t\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\t\t\thost-short-name = ResponseProcessor\n",
      "\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\tweightExtractor = FieldExtractor(\n",
      "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.FieldExtractor\n",
      "\t\t\t\t\t\t\t\t)\n",
      "\t\t\t\t\t\t\treplaceNewlinesWithSpaces = true\n",
      "\t\t\t\t\t\t\tregexMappingProcessors = Map{}\n",
      "\t\t\t\t\t\t\thost-short-name = RowProcessor\n",
      "\t\t\t\t\t\t)\n",
      "\t\t\t\t\tquote = \"\n",
      "\t\t\t\t\toutputRequired = true\n",
      "\t\t\t\t\toutputFactory = LabelFactory(\n",
      "\t\t\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n",
      "\t\t\t\t\t\t)\n",
      "\t\t\t\t\tseparator = ,\n",
      "\t\t\t\t\tdataPath = /local/ExternalRepositories/tribuo/tutorials/bezdekIris.data\n",
      "\t\t\t\t\tresource-hash = 0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC\n",
      "\t\t\t\t\tfile-modified-time = 1999-12-14T15:12:39-05:00\n",
      "\t\t\t\t\tdatasource-creation-time = 2022-10-07T12:03:48.921236415-04:00\n",
      "\t\t\t\t\thost-short-name = DataSource\n",
      "\t\t\t\t)\n",
      "\t\t\ttrain-proportion = 0.7\n",
      "\t\t\tseed = 1\n",
      "\t\t\tsize = 150\n",
      "\t\t\tis-train = true\n",
      "\t\t)\n",
      "\ttransformations = List[]\n",
      "\tis-sequence = false\n",
      "\tis-dense = true\n",
      "\tnum-examples = 105\n",
      "\tnum-features = 4\n",
      "\tnum-outputs = 3\n",
      "\ttribuo-version = 4.3.0\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "var dataset = repro.recoverDataset();\n",
    "\n",
    "System.out.println(ProvenanceUtil.formattedProvenanceString(dataset.getProvenance()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
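    "Our irises dataset was loaded in using the `CSVLoader` (recorded in the provenance as a `CSVDataSource`) and split with a 70/30 train/test split, and we can see that the reproduced training dataset has been split just as we expect: it is the training portion (`is-train = true`) with 105 of the 150 examples."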
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LogisticRegressionTrainer(\n",
      "\tclass-name = org.tribuo.classification.sgd.linear.LogisticRegressionTrainer\n",
      "\tseed = 12345\n",
      "\tminibatchSize = 1\n",
      "\tshuffle = true\n",
      "\tepochs = 5\n",
      "\toptimiser = AdaGrad(\n",
      "\t\t\tclass-name = org.tribuo.math.optimisers.AdaGrad\n",
      "\t\t\tepsilon = 0.1\n",
      "\t\t\tinitialLearningRate = 1.0\n",
      "\t\t\tinitialValue = 0.0\n",
      "\t\t\thost-short-name = StochasticGradientOptimiser\n",
      "\t\t)\n",
      "\tloggingInterval = 1000\n",
      "\tobjective = LogMulticlass(\n",
      "\t\t\tclass-name = org.tribuo.classification.sgd.objectives.LogMulticlass\n",
      "\t\t\thost-short-name = LabelObjective\n",
      "\t\t)\n",
      "\ttribuo-version = 4.3.0\n",
      "\ttrain-invocation-count = 0\n",
      "\tis-sequence = false\n",
      "\thost-short-name = Trainer\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "var trainer = repro.recoverTrainer();\n",
    "System.out.println(ProvenanceUtil.formattedProvenanceString(trainer.getProvenance()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The irises model is a logistic regression using seed `12345`, and it's the first model trained by that trainer (as `train-invocation-count` is zero)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "var reproduction = repro.reproduceFromModel();\n",
    "var reproducedModel = (LinearSGDModel) reproduction.model();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can compare the reproduced model's provenance to the original model's using our diff tool. As Tribuo records construction timestamps, the two will not be identical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"dataset\" : {\n",
      "    \"datasource\" : {\n",
      "      \"source\" : {\n",
      "        \"datasource-creation-time\" : {\n",
      "          \"original\" : \"2022-10-07T11:20:06.279351-04:00\",\n",
      "          \"reproduced\" : \"2022-10-07T12:03:48.921236415-04:00\"\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  },\n",
      "  \"java-version\" : {\n",
      "    \"original\" : \"12\",\n",
      "    \"reproduced\" : \"17.0.4.1\"\n",
      "  },\n",
      "  \"trained-at\" : {\n",
      "    \"original\" : \"2022-10-07T11:20:06.643297-04:00\",\n",
      "    \"reproduced\" : \"2022-10-07T12:03:49.150931420-04:00\"\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "System.out.println(ReproUtil.diffProvenance(loadedModel.getProvenance(),reproducedModel.getProvenance()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the timestamps are a little different, though the precise difference will depend on when you ran the irises tutorial. You may also see differences in the JVM or other machine provenance if you ran that tutorial on a different machine. If the irises dataset grows a new feature or additional rows in the same file, then the diff will show that the datasets have different numbers of features or samples, and that the file has a different hash.\n",
    "\n",
    "For some models we can easily compare the model contents, e.g., for the logistic regression we can directly compare the model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights are equal = true\n"
     ]
    }
   ],
   "source": [
    "var originalWeights = loadedModel.getWeightsCopy();\n",
    "var reproducedWeights = reproducedModel.getWeightsCopy();\n",
    "\n",
    "System.out.println(\"Weights are equal = \" + originalWeights.equals(reproducedWeights));"
   ]
  },
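  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The provenance diff printed earlier lists only the fields that differ, nested by their path, each with an `original` and a `reproduced` value. The sketch below illustrates that idea over plain nested maps using only the JDK. It is not how `ReproUtil.diffProvenance` is implemented: the `MapDiff` class, the `diff` method, and the example keys and values are all hypothetical.\n",
    "\n",
    "```java\n",
    "import java.util.LinkedHashMap;\n",
    "import java.util.Map;\n",
    "import java.util.Objects;\n",
    "import java.util.TreeSet;\n",
    "\n",
    "public class MapDiff {\n",
    "    // Recursively compares two nested maps and keeps only the entries that differ.\n",
    "    static Map<String, Object> diff(Map<String, ?> original, Map<String, ?> reproduced) {\n",
    "        Map<String, Object> out = new LinkedHashMap<>();\n",
    "        // Visit the union of the keys in sorted order so the output is stable.\n",
    "        TreeSet<String> keys = new TreeSet<>(original.keySet());\n",
    "        keys.addAll(reproduced.keySet());\n",
    "        for (String key : keys) {\n",
    "            Object a = original.get(key);\n",
    "            Object b = reproduced.get(key);\n",
    "            if (a instanceof Map<?, ?> && b instanceof Map<?, ?>) {\n",
    "                // Both sides are nested, so recurse and keep the branch only if it differs.\n",
    "                @SuppressWarnings(\"unchecked\")\n",
    "                Map<String, Object> nested = diff((Map<String, ?>) a, (Map<String, ?>) b);\n",
    "                if (!nested.isEmpty()) {\n",
    "                    out.put(key, nested);\n",
    "                }\n",
    "            } else if (!Objects.equals(a, b)) {\n",
    "                // Leaf values differ (or one side is missing), so record both.\n",
    "                Map<String, Object> leaf = new LinkedHashMap<>();\n",
    "                leaf.put(\"original\", a);\n",
    "                leaf.put(\"reproduced\", b);\n",
    "                out.put(key, leaf);\n",
    "            }\n",
    "        }\n",
    "        return out;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // Made-up values standing in for two provenance objects.\n",
    "        Map<String, Object> original = Map.of(\n",
    "            \"java-version\", \"12\",\n",
    "            \"os-name\", \"Linux\",\n",
    "            \"dataset\", Map.of(\"num-examples\", \"105\", \"creation-time\", \"t0\"));\n",
    "        Map<String, Object> reproduced = Map.of(\n",
    "            \"java-version\", \"17.0.4.1\",\n",
    "            \"os-name\", \"Linux\",\n",
    "            \"dataset\", Map.of(\"num-examples\", \"105\", \"creation-time\", \"t1\"));\n",
    "        System.out.println(diff(original, reproduced));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Fields that match on both sides, such as `os-name` and `num-examples` here, are dropped, so an empty result means the two inputs agree on every field."
   ]
  },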
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reproducing an ONNX exported Tribuo Model\n",
    "\n",
    "Tribuo models can be exported into the [ONNX](https://onnx.ai) format. When Tribuo models are exported, the model provenance is stored as a metadata field in the ONNX file. This doesn't affect anything which serves the ONNX model, but it allows Tribuo to load the provenance back in if the model is loaded as an `ONNXExternalModel`, which is Tribuo's class for loading ONNX models.\n",
    "\n",
    "To load a model in as an `ONNXExternalModel` we need to define the feature and label mappings which should be written out separately when the ONNX model is exported. We're going to cheat slightly and get them from the MNIST training set itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "var labelFactory = new LabelFactory();\n",
    "var mnistTrainSource = new IDXDataSource<>(Paths.get(\"train-images-idx3-ubyte.gz\"),Paths.get(\"train-labels-idx1-ubyte.gz\"),labelFactory);\n",
    "var mnistTestSource = new IDXDataSource<>(Paths.get(\"t10k-images-idx3-ubyte.gz\"),Paths.get(\"t10k-labels-idx1-ubyte.gz\"),labelFactory);\n",
    "var mnistTrain = new MutableDataset<>(mnistTrainSource);\n",
    "var mnistTest = new MutableDataset<>(mnistTestSource);\n",
    "\n",
    "Map<String, Integer> mnistFeatureMap = new HashMap<>();\n",
    "for (VariableInfo f : mnistTrain.getFeatureIDMap()){\n",
    "    VariableIDInfo id = (VariableIDInfo) f;\n",
    "    mnistFeatureMap.put(id.getName(),id.getID());\n",
    "}\n",
    "Map<Label, Integer> mnistOutputMap = new HashMap<>();\n",
    "for (Pair<Integer,Label> l : mnistTrain.getOutputIDInfo()) {\n",
    "    mnistOutputMap.put(l.getB(), l.getA());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's load in the ONNX file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "var ortEnv = OrtEnvironment.getEnvironment();\n",
    "var sessionOpts = new OrtSession.SessionOptions();\n",
    "var denseTransformer = new DenseTransformer();\n",
    "var labelTransformer = new LabelTransformer();\n",
    "var mnistModelPath = Paths.get(\".\",\"fm-mnist.onnx\");\n",
    "ONNXExternalModel<Label> onnx = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,\n",
    "                    denseTransformer, labelTransformer, sessionOpts, mnistModelPath, \"input\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model has two provenance objects, one from the creation of the `ONNXExternalModel`, and one from the original training run in Tribuo which is persisted inside the ONNX file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNXExternalModel(\n",
      "\tclass-name = org.tribuo.interop.onnx.ONNXExternalModel\n",
      "\tdataset = Dataset(\n",
      "\t\t\tclass-name = org.tribuo.Dataset\n",
      "\t\t\tdatasource = DataSource(\n",
      "\t\t\t\t\tdescription = unknown-external-data\n",
      "\t\t\t\t\toutputFactory = LabelFactory(\n",
      "\t\t\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n",
      "\t\t\t\t\t\t)\n",
      "\t\t\t\t\tdatasource-creation-time = 2022-10-07T12:03:57.351723125-04:00\n",
      "\t\t\t\t)\n",
      "\t\t\ttransformations = List[]\n",
      "\t\t\tis-sequence = false\n",
      "\t\t\tis-dense = false\n",
      "\t\t\tnum-examples = -1\n",
      "\t\t\tnum-features = 717\n",
      "\t\t\tnum-outputs = 10\n",
      "\t\t\ttribuo-version = 4.3.0\n",
      "\t\t)\n",
      "\ttrainer = Trainer(\n",
      "\t\t\tclass-name = org.tribuo.Trainer\n",
      "\t\t\tfileModifiedTime = 2022-10-07T11:46:10.476-04:00\n",
      "\t\t\tmodelHash = 9DD2FABC436FB75BAD6A3E061BE51022A79F140FC491C6CA8B8033253F43CD5F\n",
      "\t\t\tlocation = file:/local/ExternalRepositories/tribuo/tutorials/./fm-mnist.onnx\n",
      "\t\t)\n",
      "\ttrained-at = 2022-10-07T12:03:57.349886186-04:00\n",
      "\tinstance-values = Map{\n",
      "\t\tmodel-domain=org.tribuo.tutorials.onnxexport.fm\n",
      "\t\tmodel-graphname=FMClassificationModel\n",
      "\t\tmodel-description=factorization-machine-model - Model(class-name=org.tribuo.classification.sgd.fm.FMClassificationModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=DataSource(class-name=org.tribuo.datasource.IDXDataSource,outputPath=/local/ExternalRepositories/tribuo/tutorials/train-labels-idx1-ubyte.gz,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),featuresPath=/local/ExternalRepositories/tribuo/tutorials/train-images-idx3-ubyte.gz,features-file-modified-time=2000-07-21T14:20:24-04:00,output-resource-hash=SHA-256[3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C],datasource-creation-time=2022-10-07T11:45:53.253680-04:00,output-file-modified-time=2000-07-21T14:20:27-04:00,idx-feature-type=UBYTE,features-resource-hash=SHA-256[440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609],host-short-name=DataSource),transformations=[],is-sequence=false,is-dense=false,num-examples=60000,num-features=717,num-outputs=10,tribuo-version=4.3.0),trainer=Trainer(class-name=org.tribuo.classification.sgd.fm.FMClassificationTrainer,seed=12345,variance=0.1,minibatchSize=1,factorizedDimSize=6,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=0.1,initialValue=0.0,host-short-name=StochasticGradientOptimiser),loggingInterval=30000,objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),tribuo-version=4.3.0,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2022-10-07T11:46:09.759423-04:00,instance-values={},tribuo-version=4.3.0,java-version=12,os-name=Linux,os-arch=amd64)\n",
      "\t\tmodel-producer=Tribuo\n",
      "\t\tmodel-version=0\n",
      "\t\tinput-name=input\n",
      "\t}\n",
      "\ttribuo-version = 4.3.0\n",
      "\tjava-version = 17.0.4.1\n",
      "\tos-name = Linux\n",
      "\tos-arch = amd64\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "System.out.println(ProvenanceUtil.formattedProvenanceString(onnx.getProvenance()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ONNXExternalModel` provenance has a lot of placeholders in it, as you might expect given the information is not always present in ONNX files.\n",
    "\n",
    "We can load the Tribuo model provenance using `getTribuoProvenance()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FMClassificationModel(\n",
      "\tclass-name = org.tribuo.classification.sgd.fm.FMClassificationModel\n",
      "\tdataset = MutableDataset(\n",
      "\t\t\tclass-name = org.tribuo.MutableDataset\n",
      "\t\t\tdatasource = IDXDataSource(\n",
      "\t\t\t\t\tclass-name = org.tribuo.datasource.IDXDataSource\n",
      "\t\t\t\t\toutputFactory = LabelFactory(\n",
      "\t\t\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n",
      "\t\t\t\t\t\t)\n",
      "\t\t\t\t\toutputPath = /local/ExternalRepositories/tribuo/tutorials/train-labels-idx1-ubyte.gz\n",
      "\t\t\t\t\tfeaturesPath = /local/ExternalRepositories/tribuo/tutorials/train-images-idx3-ubyte.gz\n",
      "\t\t\t\t\tfeatures-file-modified-time = 2000-07-21T14:20:24-04:00\n",
      "\t\t\t\t\toutput-resource-hash = 3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C\n",
      "\t\t\t\t\tdatasource-creation-time = 2022-10-07T11:45:53.253680-04:00\n",
      "\t\t\t\t\toutput-file-modified-time = 2000-07-21T14:20:27-04:00\n",
      "\t\t\t\t\tidx-feature-type = UBYTE\n",
      "\t\t\t\t\tfeatures-resource-hash = 440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609\n",
      "\t\t\t\t\thost-short-name = DataSource\n",
      "\t\t\t\t)\n",
      "\t\t\ttransformations = List[]\n",
      "\t\t\tis-sequence = false\n",
      "\t\t\tis-dense = false\n",
      "\t\t\tnum-examples = 60000\n",
      "\t\t\tnum-features = 717\n",
      "\t\t\tnum-outputs = 10\n",
      "\t\t\ttribuo-version = 4.3.0\n",
      "\t\t)\n",
      "\ttrainer = FMClassificationTrainer(\n",
      "\t\t\tclass-name = org.tribuo.classification.sgd.fm.FMClassificationTrainer\n",
      "\t\t\tseed = 12345\n",
      "\t\t\tvariance = 0.1\n",
      "\t\t\tminibatchSize = 1\n",
      "\t\t\tfactorizedDimSize = 6\n",
      "\t\t\tshuffle = true\n",
      "\t\t\tepochs = 5\n",
      "\t\t\toptimiser = AdaGrad(\n",
      "\t\t\t\t\tclass-name = org.tribuo.math.optimisers.AdaGrad\n",
      "\t\t\t\t\tepsilon = 0.1\n",
      "\t\t\t\t\tinitialLearningRate = 0.1\n",
      "\t\t\t\t\tinitialValue = 0.0\n",
      "\t\t\t\t\thost-short-name = StochasticGradientOptimiser\n",
      "\t\t\t\t)\n",
      "\t\t\tloggingInterval = 30000\n",
      "\t\t\tobjective = LogMulticlass(\n",
      "\t\t\t\t\tclass-name = org.tribuo.classification.sgd.objectives.LogMulticlass\n",
      "\t\t\t\t\thost-short-name = LabelObjective\n",
      "\t\t\t\t)\n",
      "\t\t\ttribuo-version = 4.3.0\n",
      "\t\t\ttrain-invocation-count = 0\n",
      "\t\t\tis-sequence = false\n",
      "\t\t\thost-short-name = Trainer\n",
      "\t\t)\n",
      "\ttrained-at = 2022-10-07T11:46:09.759423-04:00\n",
      "\tinstance-values = Map{}\n",
      "\ttribuo-version = 4.3.0\n",
      "\tjava-version = 12\n",
      "\tos-name = Linux\n",
      "\tos-arch = amd64\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "var tribuoProvenance = onnx.getTribuoProvenance().get();\n",
    "System.out.println(ProvenanceUtil.formattedProvenanceString(tribuoProvenance));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From this provenance we can see that the model is a factorization machine trained on MNIST (as expected). So now we can build a `ReproUtil` and rebuild the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "var mnistRepro = new ReproUtil<>(tribuoProvenance,Label.class);\n",
    "\n",
    "var reproducedMNISTModel = mnistRepro.reproduceFromProvenance();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can diff the two provenances:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"dataset\" : {\n",
      "    \"datasource\" : {\n",
      "      \"datasource-creation-time\" : {\n",
      "        \"original\" : \"2022-10-07T11:45:53.253680-04:00\",\n",
      "        \"reproduced\" : \"2022-10-07T12:04:03.138366189-04:00\"\n",
      "      }\n",
      "    }\n",
      "  },\n",
      "  \"java-version\" : {\n",
      "    \"original\" : \"12\",\n",
      "    \"reproduced\" : \"17.0.4.1\"\n",
      "  },\n",
      "  \"trained-at\" : {\n",
      "    \"original\" : \"2022-10-07T11:46:09.759423-04:00\",\n",
      "    \"reproduced\" : \"2022-10-07T12:04:15.478400652-04:00\"\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "System.out.println(ReproUtil.diffProvenance(tribuoProvenance, reproducedMNISTModel.getProvenance()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, it's not very interesting: we're using the same files, so only the timestamps and, in this run, the Java version differ. Checking the model weights is tricky with an ONNX model, so we can instead check that the predictions are the same to within a small tolerance (Tribuo computes in doubles and ONNX Runtime uses floats, so the scores are slightly different). We'll borrow the `checkPredictions` function from the ONNX export tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "public boolean checkPredictions(List<Prediction<Label>> nativePredictions, List<Prediction<Label>> onnxPredictions, double delta) {\n",
    "    for (int i = 0; i < nativePredictions.size(); i++) {\n",
    "        Prediction<Label> tribuo = nativePredictions.get(i);\n",
    "        Prediction<Label> external = onnxPredictions.get(i);\n",
    "        // Check the predicted label\n",
    "        if (!tribuo.getOutput().getLabel().equals(external.getOutput().getLabel())) {\n",
    "            System.out.println(\"At index \" + i + \" predictions are not equal - \"\n",
    "                    + tribuo.getOutput().getLabel() + \" and \"\n",
    "                    + external.getOutput().getLabel());\n",
    "            return false;\n",
    "        }\n",
    "        // Check the maximum score\n",
    "        if (Math.abs(tribuo.getOutput().getScore() - external.getOutput().getScore()) > delta) {\n",
    "            System.out.println(\"At index \" + i + \" predictions are not equal - \"\n",
    "                    + tribuo.getOutput() + \" and \"\n",
    "                    + external.getOutput());\n",
    "            return false;\n",
    "        }\n",
    "        // Check the score distribution\n",
    "        for (Map.Entry<String, Label> l : tribuo.getOutputScores().entrySet()) {\n",
    "            Label other = external.getOutputScores().get(l.getKey());\n",
    "            if (other == null) {\n",
    "                System.out.println(\"At index \" + i + \" failed to find label \" + l.getKey() + \" in ORT prediction.\");\n",
    "                return false;\n",
    "            } else {\n",
    "                if (Math.abs(l.getValue().getScore() - other.getScore()) > delta) {\n",
    "                    System.out.println(\"At index \" + i + \" predictions are not equal - \"\n",
    "                            + tribuo.getOutputScores() + \" and \"\n",
    "                            + external.getOutputScores());\n",
    "                    return false;\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "    return true;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can make predictions from both models and compare the outputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions are equal = true\n"
     ]
    }
   ],
   "source": [
    "var onnxPredictions = onnx.predict(mnistTest);\n",
    "var reproducedPredictions = reproducedMNISTModel.predict(mnistTest);\n",
    "\n",
    "System.out.println(\"Predictions are equal = \" + checkPredictions(reproducedPredictions,onnxPredictions,1e-5));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with provenance diffs\n",
    "\n",
    "We can use the provenance diff methods to compute diffs for unrelated models too. We're going to train a logistic regression on MNIST and compare its model provenance against that of the ONNX factorization machine we just used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"class-name\" : {\n",
      "    \"original\" : \"org.tribuo.classification.sgd.fm.FMClassificationModel\",\n",
      "    \"reproduced\" : \"org.tribuo.classification.sgd.linear.LinearSGDModel\"\n",
      "  },\n",
      "  \"dataset\" : {\n",
      "    \"datasource\" : {\n",
      "      \"datasource-creation-time\" : {\n",
      "        \"original\" : \"2022-10-07T11:45:53.253680-04:00\",\n",
      "        \"reproduced\" : \"2022-10-07T12:03:56.006018468-04:00\"\n",
      "      }\n",
      "    }\n",
      "  },\n",
      "  \"java-version\" : {\n",
      "    \"original\" : \"12\",\n",
      "    \"reproduced\" : \"17.0.4.1\"\n",
      "  },\n",
      "  \"trained-at\" : {\n",
      "    \"original\" : \"2022-10-07T11:46:09.759423-04:00\",\n",
      "    \"reproduced\" : \"2022-10-07T12:04:24.453627627-04:00\"\n",
      "  },\n",
      "  \"trainer\" : {\n",
      "    \"class-name\" : {\n",
      "      \"original\" : \"org.tribuo.classification.sgd.fm.FMClassificationTrainer\",\n",
      "      \"reproduced\" : \"org.tribuo.classification.sgd.linear.LogisticRegressionTrainer\"\n",
      "    },\n",
      "    \"loggingInterval\" : {\n",
      "      \"original\" : \"30000\",\n",
      "      \"reproduced\" : \"1000\"\n",
      "    },\n",
      "    \"optimiser\" : {\n",
      "      \"initialLearningRate\" : {\n",
      "        \"original\" : \"0.1\",\n",
      "        \"reproduced\" : \"1.0\"\n",
      "      }\n",
      "    },\n",
      "    \"factorizedDimSize\" : {\n",
      "      \"original\" : \"6\"\n",
      "    },\n",
      "    \"variance\" : {\n",
      "      \"original\" : \"0.1\"\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "var lrTrainer = new LogisticRegressionTrainer();\n",
    "var lrModel = lrTrainer.train(mnistTrain);\n",
    "\n",
    "System.out.println(ReproUtil.diffProvenance(tribuoProvenance, lrModel.getProvenance()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This diff is longer than the others we've seen, as expected for two different models with different trainers. As expected the dataset section is mostly empty as both models are trained on an unmodified MNIST training set. The `FMClassificationTrainer` and `LogisticRegressionTrainer` show more differences, but as both are SGD based models there are many common fields. They share fields like a loss function (both used `LogMulticlass`), a gradient optimiser (both used `AdaGrad`), the number of training epochs, and the minibatch size. They used different learning rates (which do appear in the diff under `optimiser`) and the factorization machine also has a few extra parameters not found in the logistic regression, `factorizedDimSize` and `variance`, which are reported as just having an `original` value, meaning they are only found in the first provenance and not the second.\n",
    "\n",
    "The current diff format is JSON, and designed to be easily human readable. We left designing a navigable diff object which is easily inspectable from code to future work once we have a better understanding of how people want to use the generated diffs."
   ]
  },
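  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The diff string can still be inspected from code today by parsing it with any JSON library. The cell below is a minimal sketch under stated assumptions, not part of Tribuo: `DiffPaths` and its `collect` method are hypothetical names, and the input is assumed to have already been parsed into nested maps, for example with Jackson's `new ObjectMapper().readValue(diff, Map.class)` if Jackson is on the classpath. It treats any JSON object containing an `original` or `reproduced` key as a differing field and returns the dotted path to each one.\n",
    "\n",
    "Applied to the diff above, `DiffPaths.collect(parsed, \"\")` would return paths such as `java-version` and `trainer.optimiser.initialLearningRate`. Any value that is not a JSON object is skipped by this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.ArrayList;\n",
    "import java.util.List;\n",
    "import java.util.Map;\n",
    "\n",
    "// Hypothetical helper (not part of Tribuo) for walking a parsed provenance diff.\n",
    "final class DiffPaths {\n",
    "    // Returns the dotted path of every field that carries an \"original\" or \"reproduced\" value.\n",
    "    // Assumes the diff JSON has already been parsed into nested Map<String, Object> values.\n",
    "    static List<String> collect(Map<String, Object> node, String prefix) {\n",
    "        List<String> paths = new ArrayList<>();\n",
    "        for (Map.Entry<String, Object> entry : node.entrySet()) {\n",
    "            if (!(entry.getValue() instanceof Map)) {\n",
    "                continue; // Only JSON objects are walked in this sketch.\n",
    "            }\n",
    "            String path = prefix.isEmpty() ? entry.getKey() : prefix + \".\" + entry.getKey();\n",
    "            @SuppressWarnings(\"unchecked\")\n",
    "            Map<String, Object> child = (Map<String, Object>) entry.getValue();\n",
    "            if (child.containsKey(\"original\") || child.containsKey(\"reproduced\")) {\n",
    "                paths.add(path); // A leaf of the diff: the two provenances disagree here.\n",
    "            } else {\n",
    "                paths.addAll(collect(child, path)); // Recurse into nested sections.\n",
    "            }\n",
    "        }\n",
    "        return paths;\n",
    "    }\n",
    "}"
   ]
  },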
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "We showed how to load in Tribuo models and reproduce them using our automated reproducibility system. The system executes the same computations as the original training, which in most cases results in an identical model. We have noted that there are some differences between gradient descent based models that are trained on ARM and x86 architectures due to underlying differences in the JVM, but otherwise the reproductions are exact. Over time we plan to expand this reproducibility system into a full experimental framework allowing models to be rebuilt using different datasets, data transformations or training hyperparameters holding all other parameters constant."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "17.0.4.1+1-LTS-2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
